// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package uki_test

import (
	"crypto"
	"crypto/rsa"
	stdx509 "crypto/x509"
	"debug/pe"
	"encoding/pem"
	"fmt"
	"io"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"testing"
	"time"

	"github.com/siderolabs/crypto/x509"
	"github.com/stretchr/testify/require"

	"github.com/siderolabs/talos/internal/pkg/measure"
	"github.com/siderolabs/talos/internal/pkg/uki"
	"github.com/siderolabs/talos/pkg/machinery/config/generate/secrets"
	"github.com/siderolabs/talos/pkg/machinery/constants"
	"github.com/siderolabs/talos/pkg/machinery/imager/quirks"
	"github.com/siderolabs/talos/pkg/machinery/version"
	"github.com/siderolabs/talos/pkg/splash"
)

// certificateProvider wraps an x509.CertificateAuthority so it can be handed to
// uki.Builder as its SecureBootSigner.
type certificateProvider struct {
	*x509.CertificateAuthority
}

// Signer returns the CA private key as a crypto.Signer.
//
// It panics if the key does not implement crypto.Signer.
func (p *certificateProvider) Signer() crypto.Signer {
	signer := p.Key.(crypto.Signer)

	return signer
}

// Certificate returns the parsed CA certificate.
func (p *certificateProvider) Certificate() *stdx509.Certificate {
	cert := p.Crt

	return cert
}

// rsaWrapper wraps an *rsa.PrivateKey and exposes its public half through
// PublicRSAKey, which is how the test uses it as a PCR signer.
type rsaWrapper struct {
	*rsa.PrivateKey
}

// PublicRSAKey returns a pointer to the public key stored inside the wrapped
// private key (not a copy).
func (k rsaWrapper) PublicRSAKey() *rsa.PublicKey {
	priv := k.PrivateKey

	return &priv.PublicKey
}

// loadRSAKey reads a PEM-encoded PKCS#1 RSA private key from path and wraps it
// in rsaWrapper so it can be used as a measure.RSAKey.
//
// It returns an error if the file cannot be read, contains no PEM block, or the
// PEM payload is not a PKCS#1 RSA private key.
func loadRSAKey(path string) (measure.RSAKey, error) {
	keyData, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	// convert private key to rsa.PrivateKey
	rsaPrivateKeyBlock, _ := pem.Decode(keyData)
	if rsaPrivateKeyBlock == nil {
		// err is nil here (ReadFile succeeded), so a dedicated error is required;
		// returning err would hand the caller a nil key with a nil error.
		return nil, fmt.Errorf("no PEM data found in %q", path)
	}

	rsaKey, err := stdx509.ParsePKCS1PrivateKey(rsaPrivateKeyBlock.Bytes)
	if err != nil {
		return nil, fmt.Errorf("parse private key failed: %w", err)
	}

	return rsaWrapper{rsaKey}, nil
}

// TestBuildUKI tests the UKI build process for both unsigned and signed UKIs, with and
// without multi-profile UKIs, once per Talos version (each version runs as a subtest).
//
// It builds a UKI with the uki package and an equivalent UKI with the ukify tool, then
// compares the two files section by section.
//
// This test requires a patched version of ukify that accepts `enter-machined` as a phase and
// drops the trailing NUL character from the SBAT section, since Talos UKI generation doesn't
// update SBAT:
//
//	diff --git src/ukify/ukify.py src/ukify/ukify.py
//	index 639301bdb6..0ab9394c2b 100755
//	--- src/ukify/ukify.py
//	+++ src/ukify/ukify.py
//	@@ -622,6 +622,7 @@ def parse_banks(s: str) -> list[str]:
//	 KNOWN_PHASES = (
//	     'enter-initrd',
//	     'leave-initrd',
//	+    'enter-machined',
//	     'sysinit',
//	     'ready',
//	     'shutdown',
//	@@ -1046,7 +1047,7 @@ def merge_sbat(input_pe: list[Path], input_text: list[str]) -> str:
//	     return (
//	         'sbat,1,SBAT Version,sbat,1,https://github.com/rhboot/shim/blob/main/SBAT.md\n'
//	         + '\n'.join(sbat)
//	-        + '\n\x00'
//	+        + '\n'
//	     )
func TestBuildUKI(t *testing.T) {
	requiredTools := []string{
		"ukify",
		"sbverify",
		"sbsign",
		"systemd-measure",
	}

	var missingTools []string

	for _, tool := range requiredTools {
		if _, err := exec.LookPath(tool); err != nil {
			missingTools = append(missingTools, tool)
		}
	}

	if len(missingTools) > 0 {
		t.Skipf("skipping test as required tools are not found: %s", strings.Join(missingTools, ", "))
	}

	currentTime := time.Now()

	opts := []x509.Option{
		x509.RSA(true),
		x509.Bits(2048),
		x509.CommonName("test-sign"),
		x509.NotAfter(currentTime.Add(secrets.CAValidityTime)),
		x509.NotBefore(currentTime),
		x509.Organization("test-sign"),
	}

	signingKey, err := x509.NewSelfSignedCertificateAuthority(opts...)
	require.NoError(t, err)

	pcrSigningKey := "../measure/testdata/pcr-signing-key.pem"

	rsaKey, err := loadRSAKey(pcrSigningKey)
	require.NoError(t, err)
	require.NotNil(t, rsaKey, "no RSA key loaded from %s", pcrSigningKey)

	tempDir := t.TempDir()

	publicKeyBytes, err := stdx509.MarshalPKIXPublicKey(rsaKey.PublicRSAKey())
	require.NoError(t, err)

	publicKeyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  x509.PEMTypeRSAPublic,
		Bytes: publicKeyBytes,
	})

	pcrPubKey := filepath.Join(tempDir, "pcr-public.pem")

	require.NoError(t, os.WriteFile(pcrPubKey, publicKeyPEM, 0o600))

	sdBootSigned := filepath.Join(tempDir, "sd-boot-signed.efi")
	kernel := "testdata/kernel"
	initrd := "testdata/kernel" // we don't have a real initrd file, so we use the kernel file
	cmdline := "console=ttyS0"
	splashImage := filepath.Join(tempDir, "splash")
	osReleaseFile := filepath.Join(tempDir, "os-release")
	secureBootSigningKey := filepath.Join(tempDir, "sb-key.pem")
	secureBootCertificate := filepath.Join(tempDir, "sb-cert.pem")
	sdStubPath := "internal/pe/testdata/linuxx64.efi.stub"
	addonStubPath := "internal/pe/testdata/addonx64.efi.stub"

	// ukify inputs that don't depend on the Talos version are prepared once.
	require.NoError(t, os.WriteFile(splashImage, splash.GetBootImage(), 0o600))
	require.NoError(t, os.WriteFile(secureBootCertificate, signingKey.CrtPEM, 0o600))
	require.NoError(t, os.WriteFile(secureBootSigningKey, signingKey.KeyPEM, 0o600))

	kernelVersion, err := uki.DiscoverKernelVersion(kernel)
	require.NoError(t, err)

	ukiProfiles := []uki.Profile{
		{
			ID:      "reset-maintenance",
			Title:   "Reset to maintenance mode",
			Cmdline: cmdline + fmt.Sprintf(" %s=system:EPHEMERAL,STATE", constants.KernelParamWipe),
		},
		{
			ID:      "reset",
			Title:   "Reset system disk",
			Cmdline: cmdline + fmt.Sprintf(" %s=system", constants.KernelParamWipe),
		},
	}

	for _, talosVersion := range []string{"1.9.0", "1.10.0"} {
		t.Run(talosVersion, func(t *testing.T) {
			ukiUnsigned := filepath.Join(tempDir, fmt.Sprintf("uki-%s.efi", talosVersion))
			ukiSigned := filepath.Join(tempDir, fmt.Sprintf("uki-%s-signed.efi", talosVersion))
			ukifyUKIUnsigned := filepath.Join(tempDir, fmt.Sprintf("uki-ukify-%s.efi", talosVersion))
			ukifyUKISigned := filepath.Join(tempDir, fmt.Sprintf("uki-ukify-signed-%s.efi", talosVersion))

			builder := &uki.Builder{
				Arch:       "amd64",
				Version:    talosVersion,
				SdStubPath: sdStubPath,
				SdBootPath: sdStubPath, // this doesn't matter for the test
				KernelPath: kernel,
				InitrdPath: initrd,
				Cmdline:    cmdline,
				Profiles:   ukiProfiles,

				OutSdBootPath: sdBootSigned,
				OutUKIPath:    ukiUnsigned,
			}

			require.NoError(t, builder.Build(t.Logf))

			signedBuilder := &uki.Builder{
				Arch:             "amd64",
				Version:          talosVersion,
				SdStubPath:       sdStubPath,
				SdBootPath:       sdStubPath, // this doesn't matter for the test
				KernelPath:       kernel,
				InitrdPath:       initrd,
				Cmdline:          cmdline,
				SecureBootSigner: &certificateProvider{signingKey},
				PCRSigner:        rsaKey,
				Profiles:         ukiProfiles,

				OutSdBootPath: sdBootSigned,
				OutUKIPath:    ukiSigned,
			}

			require.NoError(t, signedBuilder.BuildSigned(t.Logf))

			osRelease, err := version.OSReleaseFor(version.Name, talosVersion)
			require.NoError(t, err)

			require.NoError(t, os.WriteFile(osReleaseFile, osRelease, 0o600))

			// For versions with multi-profile UKI support, build one ukify profile addon per entry of
			// ukiProfiles (the same definitions handed to the native builders) and join them into both UKIs.
			var profileArgs []string

			if quirks.New(talosVersion).SupportsUKIProfiles() {
				for i, profile := range ukiProfiles {
					profilePath := filepath.Join(tempDir, fmt.Sprintf("profile%d.efi", i))

					runUkify(t,
						"build",
						"--stub", addonStubPath,
						"--profile", fmt.Sprintf("ID=%s\nTITLE=%s", profile.ID, profile.Title),
						"--cmdline", profile.Cmdline,
						"--output", profilePath,
					)

					profileArgs = append(profileArgs, "--join-profile", profilePath)
				}
			}

			// ukifyArgs returns a fresh `ukify build` argument list holding the inputs shared by the
			// unsigned and signed variants, followed by extra and then the profile arguments.
			ukifyArgs := func(extra ...string) []string {
				args := []string{
					"build",
					"--stub", sdStubPath,
					"--sbat", "sbat,", // this is a hack so that ukify doesn't add extra sbat info
					"--splash", splashImage,
					"--linux", kernel,
					"--initrd", initrd,
					"--os-release", "@" + osReleaseFile,
					"--uname", kernelVersion,
					"--cmdline", cmdline,
				}

				args = append(args, extra...)

				return append(args, profileArgs...)
			}

			runUkify(t, ukifyArgs("--output", ukifyUKIUnsigned)...)

			compareUKIFiles(t, ukiUnsigned, ukifyUKIUnsigned)

			runUkify(t, ukifyArgs(
				"--pcr-private-key", pcrSigningKey,
				"--pcrpkey", pcrPubKey,
				"--secureboot-private-key", secureBootSigningKey,
				"--secureboot-certificate", secureBootCertificate,
				"--pcr-banks", "sha256,sha384,sha512",
				"--phases", "enter-initrd:leave-initrd:enter-machined",
				"--output", ukifyUKISigned,
			)...)

			compareUKIFiles(t, ukiSigned, ukifyUKISigned)
		})
	}
}

// runUkify runs the ukify binary with args, forwarding its stdout/stderr to the test
// process, and fails the test if the command does not succeed.
func runUkify(t *testing.T, args ...string) {
	t.Helper()

	cmd := exec.CommandContext(t.Context(), "ukify", args...)
	cmd.Stderr = os.Stderr
	cmd.Stdout = os.Stdout

	t.Log("Running ukify command:", cmd.String())

	require.NoError(t, cmd.Run())
}

// compareUKIFiles asserts that two UKI PE files have the same sections, in the same
// order, with identical contents.
//
// ukiNative is the UKI produced by the uki package (treated as the expected value) and
// ukiUKIFY the one produced by ukify. Spaces are stripped from the .pcrsig section before
// comparing, so spacing differences between the two tools' output don't cause a mismatch.
// Both files stay open until the test finishes and are closed via t.Cleanup.
func compareUKIFiles(t *testing.T, ukiNative, ukiUKIFY string) {
	t.Helper()

	ukiData, err := pe.Open(ukiNative)
	require.NoError(t, err)

	t.Cleanup(func() {
		require.NoError(t, ukiData.Close())
	})

	ukiDataUKIFY, err := pe.Open(ukiUKIFY)
	require.NoError(t, err)

	t.Cleanup(func() {
		require.NoError(t, ukiDataUKIFY.Close())
	})

	if len(ukiData.Sections) != len(ukiDataUKIFY.Sections) {
		t.Fatalf("sections count mismatch: %d != %d", len(ukiData.Sections), len(ukiDataUKIFY.Sections))
	}

	for i, section := range ukiData.Sections {
		sectionUKIFY := ukiDataUKIFY.Sections[i]

		if section.Name != sectionUKIFY.Name {
			t.Fatalf("section name mismatch: %s != %s", section.Name, sectionUKIFY.Name)
		}

		expected := readSectionData(t, section)
		actual := readSectionData(t, sectionUKIFY)

		if section.Name == ".pcrsig" {
			expected = strings.ReplaceAll(expected, " ", "")
			actual = strings.ReplaceAll(actual, " ", "")
		}

		require.Equal(t, expected, actual, "section %s at index %d differs", section.Name, i)
	}
}

// readSectionData returns the contents of section truncated to its VirtualSize. The raw
// data stored in the file can be longer than VirtualSize (PE file alignment padding), and
// that padding must not take part in the comparison.
func readSectionData(t *testing.T, section *pe.Section) string {
	t.Helper()

	data, err := io.ReadAll(io.LimitReader(section.Open(), int64(section.VirtualSize)))
	require.NoError(t, err)

	return string(data)
}
